import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms
from tensorboardX import SummaryWriter
from data_loader import get_data
from model import get_model
from apex import amp
from torch.cuda.amp import autocast, GradScaler

# Use the GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def calculate_accuracy(loader, model, device=None):
    """Return top-1 classification accuracy of ``model`` over ``loader``.

    Args:
        loader: iterable yielding ``(data, target)`` batch tuples.
        model: an ``nn.Module`` classifier whose output's argmax over dim 1
            is the predicted class.
        device: device to run evaluation on. Defaults to the device of the
            model's first parameter, which matches the previous behavior of
            relying on the module-level global (the model was always moved
            to that device first).

    Returns:
        float in [0, 1]: fraction of samples predicted correctly.
    """
    if device is None:
        # Infer from the model instead of a module-level global so the
        # function is self-contained. Assumes the model has parameters.
        device = next(model.parameters()).device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only — no autograd bookkeeping
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct / total

def main(batch_size: int, train_way: str, data_name: str = 'CIFAR10', epoch: int = 20):
    """Train a classifier and log accuracy/memory metrics to TensorBoard.

    Args:
        batch_size: mini-batch size passed to the data loader.
        train_way: one of 'apex' (NVIDIA apex AMP), 'amp' (native
            ``torch.cuda.amp``), or 'fp32' (plain full precision).
        data_name: dataset name forwarded to ``get_data``/``get_model``
            (e.g. 'CIFAR10' or 'MNIST').
        epoch: number of training epochs.

    Raises:
        ValueError: if ``train_way`` is not a recognized mode (previously an
            unknown mode silently skipped backward/step entirely).
    """
    if train_way not in ('apex', 'amp', 'fp32'):
        raise ValueError(f"train_way must be 'apex', 'amp', or 'fp32', got {train_way!r}")

    writer = SummaryWriter(f'./scalar-{data_name}/scalar-{batch_size}-{train_way}')
    train_loader, test_loader = get_data(batch_size, data_name)
    model = get_model(data_name).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    if train_way == 'apex':
        # NOTE(review): loss_scale=2040 looks like a typo for 2048 (or
        # "dynamic") — kept as-is to preserve existing behavior; confirm.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=2040)
    elif train_way == 'amp':
        scaler = GradScaler()

    try:
        # Loop variable renamed so it no longer shadows the `epoch` parameter.
        for cur_epoch in range(epoch):
            model.train()
            for i, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)

                if train_way == 'amp':
                    # Bug fix: the forward pass must run under autocast for
                    # native AMP to actually use mixed precision; GradScaler
                    # alone only scales the loss.
                    with autocast():
                        output = model(data)
                        loss = loss_fn(output, target)
                else:
                    output = model(data)
                    loss = loss_fn(output, target)

                optimizer.zero_grad()

                if train_way == 'apex':
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # Bug fix: step AFTER the scale_loss context exits — apex
                    # unscales the gradients on __exit__, so stepping inside
                    # the context updated with scaled gradients.
                    optimizer.step()
                elif train_way == 'amp':
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:  # 'fp32' — validated above, no other value possible
                    loss.backward()
                    optimizer.step()

                if i % 100 == 0:
                    print(f'Epoch: {cur_epoch}, Batch: {i}, Loss: {loss.item()}')

            train_accuracy = calculate_accuracy(train_loader, model)
            test_accuracy = calculate_accuracy(test_loader, model)
            writer.add_scalar('Accuracy/train', train_accuracy, cur_epoch)
            writer.add_scalar('Accuracy/test', test_accuracy, cur_epoch)

            if torch.cuda.is_available():
                # Track GPU memory in MiB to compare precision modes.
                memory_allocated = torch.cuda.memory_allocated(device) / 1024 ** 2
                memory_reserved = torch.cuda.memory_reserved(device) / 1024 ** 2
                writer.add_scalar('Memory/Allocated_MB', memory_allocated, cur_epoch)
                writer.add_scalar('Memory/Reserved_MB', memory_reserved, cur_epoch)

            print(f'Epoch: {cur_epoch}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}')
    finally:
        # Flush and release the event-file handle even if training fails.
        writer.close()

if __name__ == '__main__':
    # Default run: fp32 baseline on CIFAR-10 with batch size 256.
    main(batch_size=256, train_way='fp32', data_name='CIFAR10')